{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "\n",
    "import collections\n",
    "import os\n",
    "import random\n",
    "from pathlib import Path\n",
    "import logging\n",
    "import shutil\n",
    "import time\n",
    "from packaging import version\n",
    "from collections import defaultdict\n",
    "\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import gzip\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn.parallel import DistributedDataParallel as DDP\n",
    "import torch.distributed as dist\n",
    "import torch.backends.cudnn as cudnn\n",
    "\n",
    "from src.param import parse_args\n",
    "from src.utils import LossMeter\n",
    "from src.dist_utils import reduce_dict\n",
    "from transformers import T5Tokenizer, T5TokenizerFast\n",
    "from src.tokenization import P5Tokenizer, P5TokenizerFast\n",
    "from src.pretrain_model import P5Pretraining\n",
    "\n",
    "_use_native_amp = False\n",
    "_use_apex = False\n",
    "\n",
    "# Pick the mixed-precision backend from the installed PyTorch version:\n",
    "# versions before 1.6 use Apex, 1.6 and later use torch.cuda.amp.\n",
    "if version.parse(torch.__version__) < version.parse(\"1.6\"):\n",
    "    # The helper lives in the `transformers` package.\n",
    "    from transformers.file_utils import is_apex_available\n",
    "    if is_apex_available():\n",
    "        from apex import amp\n",
    "    # The flag is set on this branch whether or not apex could be imported;\n",
    "    # `amp` is only bound when apex is available.\n",
    "    _use_apex = True\n",
    "else:\n",
    "    _use_native_amp = True\n",
    "    from torch.cuda.amp import autocast\n",
    "\n",
    "from src.trainer_base import TrainerBase\n",
    "\n",
    "import pickle\n",
    "\n",
    "def load_pickle(filename):\n",
    "    with open(filename, \"rb\") as f:\n",
    "        return pickle.load(f)\n",
    "\n",
    "\n",
    "def save_pickle(data, filename):\n",
    "    with open(filename, \"wb\") as f:\n",
    "        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "        \n",
    "import json\n",
    "\n",
    "def load_json(file_path):\n",
    "    with open(file_path, \"r\") as f:\n",
    "        return json.load(f)\n",
    "    \n",
    "def ReadLineFromFile(path):\n",
    "    lines = []\n",
    "    with open(path,'r') as fd:\n",
    "        for line in fd:\n",
    "            lines.append(line.rstrip('\\n'))\n",
    "    return lines\n",
    "\n",
    "def parse(path):\n",
    "    g = gzip.open(path, 'r')\n",
    "    for l in g:\n",
    "        yield eval(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['rating_loss', 'sequential_loss', 'explanation_loss', 'review_loss', 'traditional_loss']\n",
      "Process Launching at GPU 0\n",
      "{'distributed': False, 'multiGPU': True, 'fp16': True, 'train': 'beauty', 'valid': 'beauty', 'test': 'beauty', 'batch_size': 16, 'optim': 'adamw', 'warmup_ratio': 0.05, 'lr': 0.001, 'num_workers': 4, 'clip_grad_norm': 1.0, 'losses': 'rating,sequential,explanation,review,traditional', 'backbone': 't5-small', 'output': 'snap/beauty-small', 'epoch': 10, 'local_rank': 0, 'comment': '', 'train_topk': -1, 'valid_topk': -1, 'dropout': 0.1, 'tokenizer': 'p5', 'max_text_length': 512, 'do_lower_case': False, 'word_mask_rate': 0.15, 'gen_max_length': 64, 'weight_decay': 0.01, 'adam_eps': 1e-06, 'gradient_accumulation_steps': 1, 'seed': 2022, 'whole_word_embed': True, 'world_size': 4, 'LOSSES_NAME': ['rating_loss', 'sequential_loss', 'explanation_loss', 'review_loss', 'traditional_loss', 'total_loss'], 'gpu': 0, 'rank': 0}\n"
     ]
    }
   ],
   "source": [
    "class DotDict(dict):\n",
    "    def __init__(self, **kwds):\n",
    "        self.update(kwds)\n",
    "        self.__dict__ = self\n",
    "        \n",
    "# Every run option is set by hand on this DotDict.\n",
    "args = DotDict()\n",
    "\n",
    "args.distributed = False\n",
    "args.multiGPU = True\n",
    "args.fp16 = True\n",
    "args.train = \"beauty\"\n",
    "args.valid = \"beauty\"\n",
    "args.test = \"beauty\"\n",
    "args.batch_size = 16\n",
    "args.optim = 'adamw' \n",
    "args.warmup_ratio = 0.05\n",
    "args.lr = 1e-3\n",
    "args.num_workers = 4\n",
    "args.clip_grad_norm = 1.0\n",
    "args.losses = 'rating,sequential,explanation,review,traditional'\n",
    "args.backbone = 't5-small' # small or base\n",
    "args.output = 'snap/beauty-small'\n",
    "args.epoch = 10\n",
    "args.local_rank = 0\n",
    "\n",
    "args.comment = ''\n",
    "args.train_topk = -1\n",
    "args.valid_topk = -1\n",
    "args.dropout = 0.1\n",
    "\n",
    "args.tokenizer = 'p5'\n",
    "args.max_text_length = 512\n",
    "args.do_lower_case = False\n",
    "args.word_mask_rate = 0.15\n",
    "args.gen_max_length = 64\n",
    "\n",
    "args.weight_decay = 0.01\n",
    "args.adam_eps = 1e-6\n",
    "args.gradient_accumulation_steps = 1\n",
    "\n",
    "'''\n",
    "Set seeds\n",
    "'''\n",
    "# The same seed is applied to torch, random and numpy.\n",
    "args.seed = 2022\n",
    "torch.manual_seed(args.seed)\n",
    "random.seed(args.seed)\n",
    "np.random.seed(args.seed)\n",
    "\n",
    "'''\n",
    "Whole word embedding\n",
    "'''\n",
    "args.whole_word_embed = True\n",
    "\n",
    "cudnn.benchmark = True\n",
    "# world_size is set to the number of CUDA devices torch can see.\n",
    "ngpus_per_node = torch.cuda.device_count()\n",
    "args.world_size = ngpus_per_node\n",
    "\n",
    "# One '<task>_loss' name per entry of args.losses; 'total_loss' is appended\n",
    "# after the list has been printed.\n",
    "LOSSES_NAME = [f'{name}_loss' for name in args.losses.split(',')]\n",
    "if args.local_rank in [0, -1]:\n",
    "    print(LOSSES_NAME)\n",
    "LOSSES_NAME.append('total_loss') # total loss\n",
    "\n",
    "args.LOSSES_NAME = LOSSES_NAME\n",
    "\n",
    "# The chosen GPU index is stored as both args.gpu and args.rank.\n",
    "gpu = 0 # Change GPU ID\n",
    "args.gpu = gpu\n",
    "args.rank = gpu\n",
    "print(f'Process Launching at GPU {gpu}')\n",
    "\n",
    "torch.cuda.set_device('cuda:{}'.format(gpu))\n",
    "\n",
    "# Build a run tag joined with '_': dataset names found in args.train, the\n",
    "# backbone, the concatenated loss names, and the optional free-form comment.\n",
    "comments = []\n",
    "dsets = []\n",
    "if 'toys' in args.train:\n",
    "    dsets.append('toys')\n",
    "if 'beauty' in args.train:\n",
    "    dsets.append('beauty')\n",
    "if 'sports' in args.train:\n",
    "    dsets.append('sports')\n",
    "comments.append(''.join(dsets))\n",
    "if args.backbone:\n",
    "    comments.append(args.backbone)\n",
    "comments.append(''.join(args.losses.split(',')))\n",
    "if args.comment != '':\n",
    "    comments.append(args.comment)\n",
    "comment = '_'.join(comments)\n",
    "\n",
    "if args.local_rank in [0, -1]:\n",
    "    print(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_config(args):\n",
    "    from transformers import T5Config, BartConfig\n",
    "\n",
    "    if 't5' in args.backbone:\n",
    "        config_class = T5Config\n",
    "    else:\n",
    "        return None\n",
    "\n",
    "    config = config_class.from_pretrained(args.backbone)\n",
    "    config.dropout_rate = args.dropout\n",
    "    config.dropout = args.dropout\n",
    "    config.attention_dropout = args.dropout\n",
    "    config.activation_dropout = args.dropout\n",
    "    config.losses = args.losses\n",
    "\n",
    "    return config\n",
    "\n",
    "\n",
    "def create_tokenizer(args):\n",
    "    \"\"\"Instantiate the tokenizer selected by `args.tokenizer` for `args.backbone`.\n",
    "\n",
    "    Args:\n",
    "        args: object exposing `tokenizer`, `backbone`, `max_text_length` and\n",
    "            `do_lower_case` attributes.\n",
    "\n",
    "    Returns:\n",
    "        A `P5Tokenizer` when 'p5' appears in `args.tokenizer`, otherwise a\n",
    "        `T5Tokenizer`; both are loaded from the `args.backbone` checkpoint name.\n",
    "    \"\"\"\n",
    "    from transformers import T5Tokenizer\n",
    "    from src.tokenization import P5Tokenizer\n",
    "\n",
    "    # Without the T5Tokenizer fallback, tokenizer_class would be unbound for\n",
    "    # any non-'p5' value (e.g. when args.tokenizer is set to the backbone name).\n",
    "    tokenizer_class = P5Tokenizer if 'p5' in args.tokenizer else T5Tokenizer\n",
    "\n",
    "    tokenizer_name = args.backbone\n",
    "\n",
    "    tokenizer = tokenizer_class.from_pretrained(\n",
    "        tokenizer_name,\n",
    "        max_length=args.max_text_length,\n",
    "        do_lower_case=args.do_lower_case,\n",
    "    )\n",
    "\n",
    "    print(tokenizer_class, tokenizer_name)\n",
    "\n",
    "    return tokenizer\n",
    "\n",
    "\n",
    "def create_model(model_class, config=None):\n",
    "    \"\"\"Load `model_class` from the pretrained checkpoint named by the global `args.backbone`.\n",
    "\n",
    "    Reads the notebook-level `args` (for `gpu` and `backbone`) rather than a\n",
    "    parameter, and forwards `config` to `from_pretrained`.\n",
    "    \"\"\"\n",
    "    print(f'Building Model at GPU {args.gpu}')\n",
    "    return model_class.from_pretrained(args.backbone, config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'src.tokenization.P5Tokenizer'> t5-small\n",
      "Building Model at GPU 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of P5Pretraining were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.whole_word_embeddings.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "config = create_config(args)\n",
    "\n",
    "# Fall back to the backbone name when no tokenizer was chosen.\n",
    "if args.tokenizer is None:\n",
    "    args.tokenizer = args.backbone\n",
    "    \n",
    "tokenizer = create_tokenizer(args)\n",
    "\n",
    "model_class = P5Pretraining\n",
    "model = create_model(model_class, config)\n",
    "\n",
    "model = model.cuda()\n",
    "\n",
    "# With the P5 tokenizer, resize the token embeddings to tokenizer.vocab_size.\n",
    "if 'p5' in args.tokenizer:\n",
    "    model.resize_token_embeddings(tokenizer.vocab_size)\n",
    "    \n",
    "# Attach the tokenizer to the model; evaluation cells decode via model.tokenizer.\n",
    "model.tokenizer = tokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model loaded from  ../snap/beauty-small.pth\n",
      "<All keys matched successfully>\n"
     ]
    }
   ],
   "source": [
    "args.load = \"../snap/beauty-small.pth\"\n",
    "\n",
    "# Load Checkpoint\n",
    "from src.utils import load_state_dict, LossMeter, set_global_logging_level\n",
    "from pprint import pprint\n",
    "\n",
    "def load_checkpoint(ckpt_path):\n",
    "    \"\"\"Load the weights stored at `ckpt_path` into the global `model` and print the load report.\"\"\"\n",
    "    # NOTE(review): 'cpu' is presumably the map location - confirm in src.utils.\n",
    "    state_dict = load_state_dict(ckpt_path, 'cpu')\n",
    "    # strict=False: key mismatches are returned in `results` instead of raising.\n",
    "    results = model.load_state_dict(state_dict, strict=False)\n",
    "    print('Model loaded from ', ckpt_path)\n",
    "    pprint(results)\n",
    "\n",
    "ckpt_path = args.load\n",
    "load_checkpoint(ckpt_path)\n",
    "\n",
    "from src.all_amazon_templates import all_tasks as task_templates"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Check Test Split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled rating splits and keep the 'test' portion.\n",
    "data_splits = load_pickle('../data/beauty/rating_splits_augmented.pkl')\n",
    "test_review_data = data_splits['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "19850"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of records in the test split.\n",
    "len(test_review_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'reviewerID': 'A2QKXW3LDQ66P5',\n",
       " 'asin': 'B005X2F7KI',\n",
       " 'reviewerName': 'stephanie',\n",
       " 'helpful': [5, 6],\n",
       " 'reviewText': 'Absolutely great product.  I bought this for my fourteen year old niece for Christmas and of course I had to try it out, then I tried another one, and another one and another one.  So much fun!  I even contemplated keeping a few for myself!',\n",
       " 'overall': 5.0,\n",
       " 'summary': 'Perfect!',\n",
       " 'unixReviewTime': 1352937600,\n",
       " 'reviewTime': '11 15, 2012',\n",
       " 'explanation': 'Absolutely great product',\n",
       " 'feature': 'product'}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the first test record to inspect its fields.\n",
    "test_review_data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "22363\n",
      "12101\n"
     ]
    }
   ],
   "source": [
    "# Load the beauty dataset's id maps and print the size of each mapping.\n",
    "data_maps = load_json(os.path.join('../data', 'beauty', 'datamaps.json'))\n",
    "print(len(data_maps['user2id'])) # number of users\n",
    "print(len(data_maps['item2id'])) # number of items"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test P5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, Dataset, Sampler\n",
    "from src.pretrain_data import get_loader\n",
    "from evaluate.utils import rouge_score, bleu_score, unique_sentence_percent, root_mean_square_error, mean_absolute_error, feature_detect, feature_matching_ratio, feature_coverage_ratio, feature_diversity\n",
    "from evaluate.metrics4rec import evaluate_all"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluation - Rating"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['beauty']\n",
      "compute_datum_info\n",
      "1241\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1241it [01:02, 19.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE  1.2989\n",
      "MAE  0.8473\n"
     ]
    }
   ],
   "source": [
    "# Rating evaluation with prompt '1-10': generate ratings, then report RMSE/MAE.\n",
    "test_task_list = {'rating': ['1-10'] # or '1-6'\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "gt_ratings = []\n",
    "pred_ratings = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    with torch.no_grad():\n",
    "        results = model.generate_step(batch)\n",
    "        gt_ratings.extend(batch['target_text'])\n",
    "        pred_ratings.extend(results)\n",
    "\n",
    "# Keep only predictions that are exactly one of '1.0', '1.1', ..., '5.0'.\n",
    "# The upper bound is 51 so the maximum rating '5.0' is kept; range(10, 50)\n",
    "# ended at '4.9' and dropped every 5.0 prediction from the metrics.\n",
    "# The set is built once instead of once per prediction.\n",
    "valid_ratings = {str(i / 10.0) for i in range(10, 51)}\n",
    "predicted_rating = [(float(r), float(p)) for (r, p) in zip(gt_ratings, pred_ratings) if p in valid_ratings]\n",
    "RMSE = root_mean_square_error(predicted_rating, 5.0, 1.0)\n",
    "print('RMSE {:7.4f}'.format(RMSE))\n",
    "MAE = mean_absolute_error(predicted_rating, 5.0, 1.0)\n",
    "print('MAE {:7.4f}'.format(MAE))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['beauty']\n",
      "compute_datum_info\n",
      "1241\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1241it [00:59, 20.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE  1.3128\n",
      "MAE  0.8428\n"
     ]
    }
   ],
   "source": [
    "# Rating evaluation with prompt '1-6': generate ratings, then report RMSE/MAE.\n",
    "test_task_list = {'rating': ['1-6'] # or '1-10'\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "gt_ratings = []\n",
    "pred_ratings = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    with torch.no_grad():\n",
    "        results = model.generate_step(batch)\n",
    "        gt_ratings.extend(batch['target_text'])\n",
    "        pred_ratings.extend(results)\n",
    "\n",
    "# Keep only predictions that are exactly one of '1.0', '1.1', ..., '5.0'.\n",
    "# The upper bound is 51 so the maximum rating '5.0' is kept; range(10, 50)\n",
    "# ended at '4.9' and dropped every 5.0 prediction from the metrics.\n",
    "# The set is built once instead of once per prediction.\n",
    "valid_ratings = {str(i / 10.0) for i in range(10, 51)}\n",
    "predicted_rating = [(float(r), float(p)) for (r, p) in zip(gt_ratings, pred_ratings) if p in valid_ratings]\n",
    "RMSE = root_mean_square_error(predicted_rating, 5.0, 1.0)\n",
    "print('RMSE {:7.4f}'.format(RMSE))\n",
    "MAE = mean_absolute_error(predicted_rating, 5.0, 1.0)\n",
    "print('MAE {:7.4f}'.format(MAE))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Evaluation - Sequential"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['beauty']\n",
      "compute_datum_info\n",
      "1398\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]/global/homes/z/zw241/.conda/envs/pt-1.10/lib/python3.9/site-packages/transformers/generation_utils.py:1632: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  next_indices = next_tokens // vocab_size\n",
      "1398it [16:47,  1.39it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "NDCG@5\tRec@5\tHits@5\tPrec@5\tMAP@5\tMRR@5\n",
      "0.0358\t0.0490\t0.0490\t0.0098\t0.0315\t0.0315\n",
      "\n",
      "NDCG@10\tRec@10\tHits@10\tPrec@10\tMAP@10\tMRR@10\n",
      "0.0409\t0.0646\t0.0646\t0.0065\t0.0336\t0.0336\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('\\nNDCG@10\\tRec@10\\tHits@10\\tPrec@10\\tMAP@10\\tMRR@10\\n0.0409\\t0.0646\\t0.0646\\t0.0065\\t0.0336\\t0.0336',\n",
       " {'ndcg': 0.040918043981129575,\n",
       "  'map': 0.03361417917492684,\n",
       "  'recall': 0.06457094307561598,\n",
       "  'precision': 0.0064570943075614225,\n",
       "  'mrr': 0.03361417917492684,\n",
       "  'hit': 0.06457094307561598})"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sequential recommendation with prompt '2-13': beam-search candidate item\n",
    "# ids per example, then score them with evaluate_all at cutoffs 5 and 10.\n",
    "test_task_list = {'sequential': ['2-13'] # or '2-3'\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "# Beam width; every beam is returned, so this is also the number of\n",
    "# candidates decoded per input example.\n",
    "num_candidates = 20\n",
    "\n",
    "all_info = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    with torch.no_grad():\n",
    "        # NOTE(review): `results` is only used to pair with the targets in the\n",
    "        # zip below; this extra generation pass may be removable - confirm.\n",
    "        results = model.generate_step(batch)\n",
    "        beam_outputs = model.generate(\n",
    "                batch['input_ids'].to('cuda'), \n",
    "                max_length=50, \n",
    "                num_beams=num_candidates,\n",
    "                no_repeat_ngram_size=0, \n",
    "                num_return_sequences=num_candidates,\n",
    "                early_stopping=True\n",
    "        )\n",
    "        generated_sents = model.tokenizer.batch_decode(beam_outputs, skip_special_tokens=True)\n",
    "        # The decoded list is flat: example j owns the slice [j*n, (j+1)*n).\n",
    "        for j, item in enumerate(zip(results, batch['target_text'], batch['source_text'])):\n",
    "            new_info = {}\n",
    "            new_info['target_item'] = item[1]\n",
    "            new_info['gen_item_list'] = generated_sents[j*num_candidates: (j+1)*num_candidates]\n",
    "            all_info.append(new_info)\n",
    "\n",
    "gt = {}\n",
    "ui_scores = {}\n",
    "for i, info in enumerate(all_info):\n",
    "    gt[i] = [int(info['target_item'])]\n",
    "    pred_dict = {}\n",
    "    for j, candidate in enumerate(info['gen_item_list']):\n",
    "        try:\n",
    "            item_id = int(candidate)\n",
    "        except ValueError:\n",
    "            # Generated text that is not an integer item id is skipped.\n",
    "            continue\n",
    "        # Earlier candidates get higher scores (-1, -2, ...); setdefault keeps\n",
    "        # the best rank when the same id is generated more than once.\n",
    "        pred_dict.setdefault(item_id, -(j+1))\n",
    "    ui_scores[i] = pred_dict\n",
    "\n",
    "evaluate_all(ui_scores, gt, 5)\n",
    "evaluate_all(ui_scores, gt, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['beauty']\n",
      "compute_datum_info\n",
      "1398\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1398it [17:23,  1.34it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "NDCG@5\tRec@5\tHits@5\tPrec@5\tMAP@5\tMRR@5\n",
      "0.0370\t0.0503\t0.0503\t0.0101\t0.0326\t0.0326\n",
      "\n",
      "NDCG@10\tRec@10\tHits@10\tPrec@10\tMAP@10\tMRR@10\n",
      "0.0421\t0.0659\t0.0659\t0.0066\t0.0347\t0.0347\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('\\nNDCG@10\\tRec@10\\tHits@10\\tPrec@10\\tMAP@10\\tMRR@10\\n0.0421\\t0.0659\\t0.0659\\t0.0066\\t0.0347\\t0.0347',\n",
       " {'ndcg': 0.042061379048683484,\n",
       "  'map': 0.03469676740704778,\n",
       "  'recall': 0.06586772794347806,\n",
       "  'precision': 0.006586772794347624,\n",
       "  'mrr': 0.03469676740704778,\n",
       "  'hit': 0.06586772794347806})"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sequential recommendation with prompt '2-3': beam-search candidate item\n",
    "# ids per example, then score them with evaluate_all at cutoffs 5 and 10.\n",
    "test_task_list = {'sequential': ['2-3'] # or '2-13'\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "# Beam width; every beam is returned, so this is also the number of\n",
    "# candidates decoded per input example.\n",
    "num_candidates = 20\n",
    "\n",
    "all_info = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    with torch.no_grad():\n",
    "        # NOTE(review): `results` is only used to pair with the targets in the\n",
    "        # zip below; this extra generation pass may be removable - confirm.\n",
    "        results = model.generate_step(batch)\n",
    "        beam_outputs = model.generate(\n",
    "                batch['input_ids'].to('cuda'), \n",
    "                max_length=50, \n",
    "                num_beams=num_candidates,\n",
    "                no_repeat_ngram_size=0, \n",
    "                num_return_sequences=num_candidates,\n",
    "                early_stopping=True\n",
    "        )\n",
    "        generated_sents = model.tokenizer.batch_decode(beam_outputs, skip_special_tokens=True)\n",
    "        # The decoded list is flat: example j owns the slice [j*n, (j+1)*n).\n",
    "        for j, item in enumerate(zip(results, batch['target_text'], batch['source_text'])):\n",
    "            new_info = {}\n",
    "            new_info['target_item'] = item[1]\n",
    "            new_info['gen_item_list'] = generated_sents[j*num_candidates: (j+1)*num_candidates]\n",
    "            all_info.append(new_info)\n",
    "\n",
    "gt = {}\n",
    "ui_scores = {}\n",
    "for i, info in enumerate(all_info):\n",
    "    gt[i] = [int(info['target_item'])]\n",
    "    pred_dict = {}\n",
    "    for j, candidate in enumerate(info['gen_item_list']):\n",
    "        try:\n",
    "            item_id = int(candidate)\n",
    "        except ValueError:\n",
    "            # Generated text that is not an integer item id is skipped.\n",
    "            continue\n",
    "        # Earlier candidates get higher scores (-1, -2, ...); setdefault keeps\n",
    "        # the best rank when the same id is generated more than once.\n",
    "        pred_dict.setdefault(item_id, -(j+1))\n",
    "    ui_scores[i] = pred_dict\n",
    "\n",
    "evaluate_all(ui_scores, gt, 5)\n",
    "evaluate_all(ui_scores, gt, 10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluation - Explanation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['beauty']\n",
      "compute_datum_info\n",
      "839\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "839it [07:53,  1.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BLEU-1 18.4199\n",
      "BLEU-4  2.6987\n",
      "rouge_1/f_score 25.6258\n",
      "rouge_1/r_score 25.1873\n",
      "rouge_1/p_score 32.0580\n",
      "rouge_2/f_score  5.5327\n",
      "rouge_2/r_score  6.1057\n",
      "rouge_2/p_score  6.6030\n",
      "rouge_l/f_score 18.6548\n",
      "rouge_l/r_score 22.9544\n",
      "rouge_l/p_score 23.5079\n"
     ]
    }
   ],
   "source": [
    "# Explanation generation with prompt '3-12': generate one explanation per\n",
    "# example, then report BLEU-1, BLEU-4 and ROUGE against the target text.\n",
    "test_task_list = {'explanation': ['3-12'] # or '3-9' or '3-3'\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "tokens_predict = []\n",
    "tokens_test = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    with torch.no_grad():\n",
    "        # 12 beams in 3 beam groups, one returned sequence per input.\n",
    "        # NOTE(review): repetition_penalty below 1.0 appears to favour repeated\n",
    "        # tokens in Hugging Face generate - confirm 0.7 is intended.\n",
    "        outputs = model.generate(\n",
    "                batch['input_ids'].to('cuda'), \n",
    "                min_length=9,\n",
    "                num_beams=12,\n",
    "                num_return_sequences=1,\n",
    "                num_beam_groups=3,\n",
    "                repetition_penalty=0.7\n",
    "        )\n",
    "        results = model.tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "        tokens_predict.extend(results) \n",
    "        tokens_test.extend(batch['target_text'])\n",
    "        \n",
    "# BLEU receives whitespace-split token lists; ROUGE receives the raw strings.\n",
    "new_tokens_predict = [l.split() for l in tokens_predict]\n",
    "new_tokens_test = [ll.split() for ll in tokens_test]\n",
    "BLEU1 = bleu_score(new_tokens_test, new_tokens_predict, n_gram=1, smooth=False)\n",
    "BLEU4 = bleu_score(new_tokens_test, new_tokens_predict, n_gram=4, smooth=False)\n",
    "ROUGE = rouge_score(tokens_test, tokens_predict)\n",
    "\n",
    "print('BLEU-1 {:7.4f}'.format(BLEU1))\n",
    "print('BLEU-4 {:7.4f}'.format(BLEU4))\n",
    "for (k, v) in ROUGE.items():\n",
    "    print('{} {:7.4f}'.format(k, v))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['beauty']\n",
      "compute_datum_info\n",
      "839\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "839it [09:12,  1.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BLEU-1 19.9696\n",
      "BLEU-4  2.7584\n",
      "rouge_1/f_score 24.9963\n",
      "rouge_1/r_score 26.8389\n",
      "rouge_1/p_score 28.0289\n",
      "rouge_2/f_score  5.1078\n",
      "rouge_2/r_score  6.1455\n",
      "rouge_2/p_score  5.5334\n",
      "rouge_l/f_score 18.4491\n",
      "rouge_l/r_score 23.0867\n",
      "rouge_l/p_score 22.6628\n"
     ]
    }
   ],
   "source": [
    "# Explanation generation with prompt '3-9': generate one explanation per\n",
    "# example, then report BLEU-1, BLEU-4 and ROUGE against the target text.\n",
    "test_task_list = {'explanation': ['3-9'] # or '3-12' or '3-3'\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "tokens_predict = []\n",
    "tokens_test = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    with torch.no_grad():\n",
    "        # 12 beams in 3 beam groups, one returned sequence per input.\n",
    "        outputs = model.generate(\n",
    "                batch['input_ids'].to('cuda'), \n",
    "                min_length=10,\n",
    "                num_beams=12,\n",
    "                num_return_sequences=1,\n",
    "                num_beam_groups=3\n",
    "        )\n",
    "        results = model.tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "        tokens_predict.extend(results) \n",
    "        tokens_test.extend(batch['target_text'])\n",
    "        \n",
    "# BLEU receives whitespace-split token lists; ROUGE receives the raw strings.\n",
    "new_tokens_predict = [l.split() for l in tokens_predict]\n",
    "new_tokens_test = [ll.split() for ll in tokens_test]\n",
    "BLEU1 = bleu_score(new_tokens_test, new_tokens_predict, n_gram=1, smooth=False)\n",
    "BLEU4 = bleu_score(new_tokens_test, new_tokens_predict, n_gram=4, smooth=False)\n",
    "ROUGE = rouge_score(tokens_test, tokens_predict)\n",
    "\n",
    "print('BLEU-1 {:7.4f}'.format(BLEU1))\n",
    "print('BLEU-4 {:7.4f}'.format(BLEU4))\n",
    "for (k, v) in ROUGE.items():\n",
    "    print('{} {:7.4f}'.format(k, v))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['beauty']\n",
      "compute_datum_info\n",
      "839\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "839it [03:27,  4.03it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BLEU-1 15.5223\n",
      "BLEU-4  0.9783\n",
      "rouge_1/f_score 17.0412\n",
      "rouge_1/r_score 18.2074\n",
      "rouge_1/p_score 18.9502\n",
      "rouge_2/f_score  1.8962\n",
      "rouge_2/r_score  2.3611\n",
      "rouge_2/p_score  2.0044\n",
      "rouge_l/f_score 12.1709\n",
      "rouge_l/r_score 15.3009\n",
      "rouge_l/p_score 14.4041\n"
     ]
    }
   ],
   "source": [
    "# Explanation generation with prompt '3-3': generate one explanation per\n",
    "# example, then report BLEU-1, BLEU-4 and ROUGE against the target text.\n",
    "test_task_list = {'explanation': ['3-3'] # or '3-12' or '3-9'\n",
    "}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test, \n",
    "        mode='test', \n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "tokens_predict = []\n",
    "tokens_test = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    with torch.no_grad():\n",
    "        # Only min_length is overridden; other generation settings are defaults.\n",
    "        outputs = model.generate(\n",
    "                batch['input_ids'].to('cuda'), \n",
    "                min_length=10\n",
    "        )\n",
    "        results = model.tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "        tokens_predict.extend(results) \n",
    "        tokens_test.extend(batch['target_text'])\n",
    "        \n",
    "# BLEU receives whitespace-split token lists; ROUGE receives the raw strings.\n",
    "new_tokens_predict = [l.split() for l in tokens_predict]\n",
    "new_tokens_test = [ll.split() for ll in tokens_test]\n",
    "BLEU1 = bleu_score(new_tokens_test, new_tokens_predict, n_gram=1, smooth=False)\n",
    "BLEU4 = bleu_score(new_tokens_test, new_tokens_predict, n_gram=4, smooth=False)\n",
    "ROUGE = rouge_score(tokens_test, tokens_predict)\n",
    "\n",
    "print('BLEU-1 {:7.4f}'.format(BLEU1))\n",
    "print('BLEU-4 {:7.4f}'.format(BLEU4))\n",
    "for (k, v) in ROUGE.items():\n",
    "    print('{} {:7.4f}'.format(k, v))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Evaluation - Review\n",
    "\n",
    "Because inference with the T0 and GPT-2 checkpoints hosted on the Hugging Face platform is slow, we only evaluate on the first 800 instances for the prompts in Task Family 4."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['beauty']\n",
      "compute_datum_info\n",
      "1241\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "51it [00:02, 19.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE  0.6262\n",
      "MAE  0.3113\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Zero-shot review-based rating prediction with prompt '4-4', scored with RMSE and MAE.\n",
    "test_task_list = {'review': ['4-4']}  # or '4-2'\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test,\n",
    "        mode='test',\n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "# Evaluate exactly this many batches. 50 batches is 800 instances if\n",
    "# args.batch_size is 16 -- TODO confirm against the args cell.\n",
    "max_eval_batches = 50\n",
    "\n",
    "gt_ratings = []\n",
    "pred_ratings = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    # `>=` (not `>`): batch indices 0..49 are processed, i.e. max_eval_batches batches.\n",
    "    if i >= max_eval_batches:\n",
    "        break\n",
    "    with torch.no_grad():\n",
    "        results = model.generate_step(batch)\n",
    "    gt_ratings.extend(batch['target_text'])\n",
    "    pred_ratings.extend(results)\n",
    "\n",
    "# Pairs of (ground-truth rating, prediction rounded to the nearest integer).\n",
    "predicted_rating = [(float(r), round(float(p))) for (r, p) in zip(gt_ratings, pred_ratings)]\n",
    "RMSE = root_mean_square_error(predicted_rating, 5.0, 1.0)\n",
    "print('RMSE {:7.4f}'.format(RMSE))\n",
    "MAE = mean_absolute_error(predicted_rating, 5.0, 1.0)\n",
    "print('MAE {:7.4f}'.format(MAE))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['beauty']\n",
      "compute_datum_info\n",
      "1241\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "51it [00:02, 19.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE  0.6233\n",
      "MAE  0.3051\n"
     ]
    }
   ],
   "source": [
    "# Zero-shot review-based rating prediction with prompt '4-2', scored with RMSE and MAE.\n",
    "test_task_list = {'review': ['4-2']}  # or '4-4'\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test,\n",
    "        mode='test',\n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "# Evaluate exactly this many batches. 50 batches is 800 instances if\n",
    "# args.batch_size is 16 -- TODO confirm against the args cell.\n",
    "max_eval_batches = 50\n",
    "\n",
    "gt_ratings = []\n",
    "pred_ratings = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    # `>=` (not `>`): batch indices 0..49 are processed, i.e. max_eval_batches batches.\n",
    "    if i >= max_eval_batches:\n",
    "        break\n",
    "    with torch.no_grad():\n",
    "        results = model.generate_step(batch)\n",
    "    gt_ratings.extend(batch['target_text'])\n",
    "    pred_ratings.extend(results)\n",
    "\n",
    "# Pairs of (ground-truth rating, prediction rounded to the nearest integer).\n",
    "predicted_rating = [(float(r), round(float(p))) for (r, p) in zip(gt_ratings, pred_ratings)]\n",
    "RMSE = root_mean_square_error(predicted_rating, 5.0, 1.0)\n",
    "print('RMSE {:7.4f}'.format(RMSE))\n",
    "MAE = mean_absolute_error(predicted_rating, 5.0, 1.0)\n",
    "print('MAE {:7.4f}'.format(MAE))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['beauty']\n",
      "compute_datum_info\n",
      "1241\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "51it [00:06,  7.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BLEU-2  2.1225\n",
      "rouge_1/f_score  8.4205\n",
      "rouge_1/r_score  7.5503\n",
      "rouge_1/p_score 11.1520\n",
      "rouge_2/f_score  1.6676\n",
      "rouge_2/r_score  1.5984\n",
      "rouge_2/p_score  1.9812\n",
      "rouge_l/f_score  7.5476\n",
      "rouge_l/r_score  7.5304\n",
      "rouge_l/p_score 11.1520\n"
     ]
    }
   ],
   "source": [
    "# Zero-shot text generation for review prompt '4-1', scored against the\n",
    "# reference text with BLEU-2 and ROUGE.\n",
    "test_task_list = {'review': ['4-1']}\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test,\n",
    "        mode='test',\n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "# Evaluate exactly this many batches. 50 batches is 800 instances if\n",
    "# args.batch_size is 16 -- TODO confirm against the args cell.\n",
    "max_eval_batches = 50\n",
    "\n",
    "tokens_predict = []\n",
    "tokens_test = []\n",
    "for i, batch in tqdm(enumerate(zeroshot_test_loader)):\n",
    "    # `>=` (not `>`): batch indices 0..49 are processed, i.e. max_eval_batches batches.\n",
    "    if i >= max_eval_batches:\n",
    "        break\n",
    "    with torch.no_grad():\n",
    "        results = model.generate_step(batch)\n",
    "    tokens_predict.extend(results)\n",
    "    tokens_test.extend(batch['target_text'])\n",
    "\n",
    "# bleu_score receives whitespace-split token lists; rouge_score receives the raw strings.\n",
    "new_tokens_predict = [sentence.split() for sentence in tokens_predict]\n",
    "new_tokens_test = [sentence.split() for sentence in tokens_test]\n",
    "BLEU2 = bleu_score(new_tokens_test, new_tokens_predict, n_gram=2, smooth=False)\n",
    "ROUGE = rouge_score(tokens_test, tokens_predict)\n",
    "\n",
    "print('BLEU-2 {:7.4f}'.format(BLEU2))\n",
    "for (k, v) in ROUGE.items():\n",
    "    print('{} {:7.4f}'.format(k, v))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Evaluation - Traditional"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['beauty']\n",
      "compute_datum_info\n",
      "1398\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]/global/homes/z/zw241/.conda/envs/pt-1.10/lib/python3.9/site-packages/transformers/generation_utils.py:1632: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n",
      "  next_indices = next_tokens // vocab_size\n",
      "1398it [17:55,  1.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "NDCG@1\tRec@1\tHits@1\tPrec@1\tMAP@1\tMRR@1\n",
      "0.0598\t0.0598\t0.0598\t0.0598\t0.0598\t0.0598\n",
      "\n",
      "NDCG@5\tRec@5\tHits@5\tPrec@5\tMAP@5\tMRR@5\n",
      "0.1101\t0.1589\t0.1589\t0.0318\t0.0940\t0.0940\n",
      "\n",
      "NDCG@10\tRec@10\tHits@10\tPrec@10\tMAP@10\tMRR@10\n",
      "0.1340\t0.2332\t0.2332\t0.0233\t0.1039\t0.1039\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('\\nNDCG@10\\tRec@10\\tHits@10\\tPrec@10\\tMAP@10\\tMRR@10\\n0.1340\\t0.2332\\t0.2332\\t0.0233\\t0.1039\\t0.1039',\n",
       " {'ndcg': 0.13398695780876257,\n",
       "  'map': 0.10386263733533777,\n",
       "  'recall': 0.23315297589768816,\n",
       "  'precision': 0.02331529758977105,\n",
       "  'mrr': 0.10386263733533777,\n",
       "  'hit': 0.23315297589768816})"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Zero-shot direct recommendation with prompt '5-8': beam search proposes\n",
    "# `num_beams` candidate item ids per prompt; candidates are ranked by beam order\n",
    "# and scored with evaluate_all at cut-offs 1, 5 and 10.\n",
    "test_task_list = {'traditional': ['5-8']}  # or '5-5'\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test,\n",
    "        mode='test',\n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "# Number of beams searched and of candidate sequences returned per prompt.\n",
    "num_beams = 20\n",
    "\n",
    "all_info = []\n",
    "# Wrap the loader itself so tqdm can read its length and show a progress bar.\n",
    "for batch in tqdm(zeroshot_test_loader):\n",
    "    # Only the beam-search candidates are evaluated, so no separate\n",
    "    # generate_step pass is run for the batch.\n",
    "    with torch.no_grad():\n",
    "        beam_outputs = model.generate(\n",
    "                batch['input_ids'].to('cuda'),\n",
    "                max_length=50,\n",
    "                num_beams=num_beams,\n",
    "                no_repeat_ngram_size=0,\n",
    "                num_return_sequences=num_beams,\n",
    "                early_stopping=True\n",
    "        )\n",
    "    generated_sents = model.tokenizer.batch_decode(beam_outputs, skip_special_tokens=True)\n",
    "    # Assumes generate() returns each prompt's sequences contiguously, so\n",
    "    # prompt j owns the slice [j*num_beams, (j+1)*num_beams).\n",
    "    for j, target_item in enumerate(batch['target_text']):\n",
    "        all_info.append({\n",
    "            'target_item': target_item,\n",
    "            'gen_item_list': generated_sents[j*num_beams: (j+1)*num_beams],\n",
    "        })\n",
    "\n",
    "gt = {}\n",
    "ui_scores = {}\n",
    "for i, info in enumerate(all_info):\n",
    "    gt[i] = [int(info['target_item'])]\n",
    "    pred_dict = {}\n",
    "    for rank, generated in enumerate(info['gen_item_list'], start=1):\n",
    "        try:\n",
    "            item_id = int(generated)\n",
    "        except ValueError:\n",
    "            # The decoded text is not an integer item id; drop this candidate.\n",
    "            continue\n",
    "        # The score is the negated rank, so earlier beams score higher.\n",
    "        # setdefault keeps the earliest rank when two beams decode to the\n",
    "        # same item id instead of overwriting it with a worse one.\n",
    "        pred_dict.setdefault(item_id, -rank)\n",
    "    ui_scores[i] = pred_dict\n",
    "\n",
    "evaluate_all(ui_scores, gt, 1)\n",
    "evaluate_all(ui_scores, gt, 5)\n",
    "evaluate_all(ui_scores, gt, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Data sources:  ['beauty']\n",
      "compute_datum_info\n",
      "1398\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1398it [17:42,  1.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "NDCG@1\tRec@1\tHits@1\tPrec@1\tMAP@1\tMRR@1\n",
      "0.0595\t0.0595\t0.0595\t0.0595\t0.0595\t0.0595\n",
      "\n",
      "NDCG@5\tRec@5\tHits@5\tPrec@5\tMAP@5\tMRR@5\n",
      "0.1112\t0.1606\t0.1606\t0.0321\t0.0949\t0.0949\n",
      "\n",
      "NDCG@10\tRec@10\tHits@10\tPrec@10\tMAP@10\tMRR@10\n",
      "0.1352\t0.2352\t0.2352\t0.0235\t0.1047\t0.1047\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('\\nNDCG@10\\tRec@10\\tHits@10\\tPrec@10\\tMAP@10\\tMRR@10\\n0.1352\\t0.2352\\t0.2352\\t0.0235\\t0.1047\\t0.1047',\n",
       " {'ndcg': 0.13516935746926673,\n",
       "  'map': 0.10474829455400039,\n",
       "  'recall': 0.23520994499843492,\n",
       "  'precision': 0.023520994499845772,\n",
       "  'mrr': 0.10474829455400039,\n",
       "  'hit': 0.23520994499843492})"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Zero-shot direct recommendation with prompt '5-5': beam search proposes\n",
    "# `num_beams` candidate item ids per prompt; candidates are ranked by beam order\n",
    "# and scored with evaluate_all at cut-offs 1, 5 and 10.\n",
    "test_task_list = {'traditional': ['5-5']}  # or '5-8'\n",
    "test_sample_numbers = {'rating': 1, 'sequential': (1, 1, 1), 'explanation': 1, 'review': 1, 'traditional': (1, 1)}\n",
    "\n",
    "zeroshot_test_loader = get_loader(\n",
    "        args,\n",
    "        test_task_list,\n",
    "        test_sample_numbers,\n",
    "        split=args.test,\n",
    "        mode='test',\n",
    "        batch_size=args.batch_size,\n",
    "        workers=args.num_workers,\n",
    "        distributed=args.distributed\n",
    ")\n",
    "print(len(zeroshot_test_loader))\n",
    "\n",
    "# Number of beams searched and of candidate sequences returned per prompt.\n",
    "num_beams = 20\n",
    "\n",
    "all_info = []\n",
    "# Wrap the loader itself so tqdm can read its length and show a progress bar.\n",
    "for batch in tqdm(zeroshot_test_loader):\n",
    "    # Only the beam-search candidates are evaluated, so no separate\n",
    "    # generate_step pass is run for the batch.\n",
    "    with torch.no_grad():\n",
    "        beam_outputs = model.generate(\n",
    "                batch['input_ids'].to('cuda'),\n",
    "                max_length=50,\n",
    "                num_beams=num_beams,\n",
    "                no_repeat_ngram_size=0,\n",
    "                num_return_sequences=num_beams,\n",
    "                early_stopping=True\n",
    "        )\n",
    "    generated_sents = model.tokenizer.batch_decode(beam_outputs, skip_special_tokens=True)\n",
    "    # Assumes generate() returns each prompt's sequences contiguously, so\n",
    "    # prompt j owns the slice [j*num_beams, (j+1)*num_beams).\n",
    "    for j, target_item in enumerate(batch['target_text']):\n",
    "        all_info.append({\n",
    "            'target_item': target_item,\n",
    "            'gen_item_list': generated_sents[j*num_beams: (j+1)*num_beams],\n",
    "        })\n",
    "\n",
    "gt = {}\n",
    "ui_scores = {}\n",
    "for i, info in enumerate(all_info):\n",
    "    gt[i] = [int(info['target_item'])]\n",
    "    pred_dict = {}\n",
    "    for rank, generated in enumerate(info['gen_item_list'], start=1):\n",
    "        try:\n",
    "            item_id = int(generated)\n",
    "        except ValueError:\n",
    "            # The decoded text is not an integer item id; drop this candidate.\n",
    "            continue\n",
    "        # The score is the negated rank, so earlier beams score higher.\n",
    "        # setdefault keeps the earliest rank when two beams decode to the\n",
    "        # same item id instead of overwriting it with a worse one.\n",
    "        pred_dict.setdefault(item_id, -rank)\n",
    "    ui_scores[i] = pred_dict\n",
    "\n",
    "evaluate_all(ui_scores, gt, 1)\n",
    "evaluate_all(ui_scores, gt, 5)\n",
    "evaluate_all(ui_scores, gt, 10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pt-1.10",
   "language": "python",
   "name": "pt-1.10"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
